support weight-update in disaggregated mode using sglang #1766

PengchengShi00 wants to merge 4 commits into
Conversation
@@ -1155,5 +1158,7 @@ async def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict):
        self.fake_update_weights()

    def fake_update_weights(self):
Rename the function: if this is a real update-weight operation, please remove the `fake` prefix.
Updated in the latest commit: fake_update_weights has been renamed to update_weights.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
TEST_DIR = Path(__file__).resolve().parent
if str(TEST_DIR) not in sys.path:
    sys.path.insert(0, str(TEST_DIR))
This is unnecessary in the CI unit tests.
    def pause_generation(self):
-        return self._make_request("pause_generation")
+        return self._make_request("pause_generation", {"mode": "retract"})
Please add a comment explaining the extra parameter.
I have added a comment in the latest commit.
SGLang PauseGeneration supports three modes:
- `abort`: drop both waiting and running requests
- `retract`: keep waiting/running requests and generated tokens, release the KV cache, and recompute KV on resume
- `in_place`: keep waiting/running requests, generated tokens, and the KV cache, and resume directly
I also changed the mode to abort. Before update_weights, send_abort_request has already been issued, so there should be no pending requests to preserve. In this case, abort is sufficient and makes the intended behavior clearer.
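The mode choice described above can be sketched as follows. This is a minimal illustration, not the actual xtuner/SGLang client: the `RolloutClient` class and the body of `_make_request` are stand-ins (the real client issues HTTP requests to the SGLang server), but the mode semantics mirror the three modes listed above.

```python
class RolloutClient:
    """Illustrative stand-in for the rollout control client."""

    def __init__(self):
        self.sent = []  # record of (endpoint, payload) calls, for illustration

    def _make_request(self, endpoint, payload=None):
        # The real client would POST to the SGLang server endpoint here;
        # we only record the call so the sketch is self-contained.
        self.sent.append((endpoint, payload))
        return {"status": "ok"}

    def pause_generation(self, mode="abort"):
        # SGLang PauseGeneration accepts a "mode" field:
        #   abort    - drop waiting and running requests
        #   retract  - keep requests, release KV cache, recompute KV on resume
        #   in_place - keep requests, tokens, and KV cache; resume directly
        # "abort" is safe before update_weights because send_abort_request
        # has already cleared pending requests.
        return self._make_request("pause_generation", {"mode": mode})
```

Defaulting to `abort` makes the intended behavior explicit at the call site while still allowing `retract` or `in_place` to be passed where requests must survive the pause.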
ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
train_controller.update_weights()
first_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers])

ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
train_controller.update_weights()
second_hashes = ray.get([worker.get_update_weight_sha256.remote() for worker in train_controller.workers])
This hash logic only verifies that the weight update is deterministic. The hash comparison should instead be between the state_dict sent by the training side and the state_dict received by the rollout side.
In the latest commit, I added per-bucket hash checks for both the training-side sent state_dict and the rollout-side received state_dict.
To support this comparison, I also needed a small SGLang-side change so the rollout side can return the received bucket hash. I have rebuilt the docker image with that SGLang patch applied.
In the latest commit, the unit tests now verify:
- The rollout output remains unchanged for the same input before and after the weight update.
- For each bucket, the training-side sent state_dict hash matches the rollout-side received state_dict hash.
- Across two consecutive weight updates, the training-side sent bucket state_dict hashes remain identical.
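The per-bucket check described above can be sketched as follows. This is a simplified illustration, not the actual implementation: `bucket_sha256` is a hypothetical helper, and parameters are assumed to be already serialized to bytes in a deterministic order (in the real code they are tensors grouped into broadcast buckets).

```python
import hashlib

def bucket_sha256(named_params, bucket_size):
    """Hash parameters bucket-by-bucket.

    named_params: iterable of (name, bytes) pairs in a deterministic order.
    Returns one hex digest per bucket, so the training side (before send)
    and the rollout side (after receive) can compare digests directly.
    """
    digests = []
    hasher = hashlib.sha256()
    filled = 0
    for name, raw in named_params:
        hasher.update(name.encode())  # bind the digest to parameter identity
        hasher.update(raw)
        filled += len(raw)
        if filled >= bucket_size:  # bucket is full: emit its digest
            digests.append(hasher.hexdigest())
            hasher = hashlib.sha256()
            filled = 0
    if filled:  # flush the last, partially filled bucket
        digests.append(hasher.hexdigest())
    return digests
```

Because the digests are computed per bucket, a mismatch pinpoints which bucket was corrupted in transit, and re-running the same update yields identical digest lists, which is what the "two consecutive weight updates" check relies on.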
self.request_update_params(state_dict, train_enable_ep=train_enable_ep, finished=False)
del state_dict, name_list, param_list

if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
Suggested change:

- if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
+ if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update:
)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_update_weight_and_generate(self):
This is not a disaggregated case and duplicates the case in tests/rl/test_update_weight.py; it should be removed or force-skipped for now.
from xtuner.v1.utils import ray_method

TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}]
MODEL_PATH = os.environ.get("MODEL_PATH") or os.environ.get("QWEN3_VL_DENSE_PATH")
`os.environ.get("MODEL_PATH")` should be removed here; just use the original CI env variable.
)

# training config
model_cfg = get_model_config_from_hf(Path(MODEL_PATH))
Maybe use a specific model config here to match QWEN3_VL_DENSE_PATH; see tests/rl/test_update_weight.py.
def setUpClass(cls) -> None:
    if MODEL_PATH is None:
        raise unittest.SkipTest("MODEL_PATH is not set")
    os.environ["XTUNER_USE_FA3"] = "1"
NCCL_CUMEM_ENABLE=0 is required in my test environment; we could add it here, or at the Actor-creation stage via the runtime_env of ray.remote.
NCCL_CUMEM_ENABLE=0 is also required in my test environment. The latest commit has already added this environment variable in the unit tests.
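Both placements mentioned above can be sketched as follows; `TrainWorker` in the commented Ray option is a hypothetical actor name, not the actual test code.

```python
import os

# Option 1: set the variable in the test process before Ray/NCCL
# initialize, so child worker processes inherit it. setdefault avoids
# clobbering a value the CI environment may already provide.
os.environ.setdefault("NCCL_CUMEM_ENABLE", "0")

# Option 2 (sketch): inject it at actor-creation time via Ray's
# runtime_env, so it applies only to the workers that need it:
#
#   import ray
#   worker = TrainWorker.options(  # TrainWorker is a hypothetical actor
#       runtime_env={"env_vars": {"NCCL_CUMEM_ENABLE": "0"}}
#   ).remote(...)
```

Option 1 is simpler for unit tests; Option 2 scopes the setting to specific actors, which is useful when other actors in the same job must keep the NCCL default.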
BTW, please fix the CI lint.
a. Create a gloo group among the training ranks; during disaggregated (train/rollout separated) weight sync, this group is used for barriers.
b. Create an NCCL process group, used by training rank 0 to broadcast the bucketed weights to the SGLang rollout ranks:
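The two groups described in a. and b. can be sketched as follows, assuming torch.distributed and a global rank layout where rollout ranks are numbered after the training ranks; `group_ranks` is a hypothetical helper and the commented calls show roughly how the groups would be used.

```python
def group_ranks(num_train_ranks, num_rollout_ranks):
    """Return (gloo_ranks, nccl_ranks) for the two process groups:
    - gloo_ranks: all training ranks, used for the train-side barrier (a.)
    - nccl_ranks: train rank 0 plus the rollout ranks, used to broadcast
      the bucketed weights from rank 0 to the rollout side (b.)
    Assumes rollout ranks are numbered after the training ranks.
    """
    gloo_ranks = list(range(num_train_ranks))
    nccl_ranks = [0] + list(range(num_train_ranks, num_train_ranks + num_rollout_ranks))
    return gloo_ranks, nccl_ranks

# With torch.distributed initialized, the groups would be created roughly as:
#
#   import torch.distributed as dist
#   gloo_ranks, nccl_ranks = group_ranks(8, 4)
#   gloo_group = dist.new_group(ranks=gloo_ranks, backend="gloo")
#   nccl_group = dist.new_group(ranks=nccl_ranks, backend="nccl")
#   dist.barrier(group=gloo_group)                           # a. train-side barrier
#   dist.broadcast(bucket_tensor, src=0, group=nccl_group)   # b. rank 0 -> rollout
```

Using gloo for the barrier keeps the CPU-side synchronization off the GPU, while the NCCL group carries only the bucketed weight broadcast from training rank 0 to the SGLang rollout ranks.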